{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('..')\n",
    "from dataset import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "from torchvision.transforms import ToPILImage\n",
    "from torchvision.utils import make_grid\n",
    "from IPython.display import display\n",
    "from torch.utils.data import ConcatDataset\n",
    "charset = CharsetMapper('data/charset_36.txt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_all(dl, iter_size=None):\n",
    "    \"\"\"Print the decoded labels and display an image grid for batches of `dl`.\n",
    "\n",
    "    Args:\n",
    "        dl: iterable of batches; each batch is read as `item[0]` (images) and\n",
    "            `item[1]`, whose first two entries are the labels and their lengths.\n",
    "        iter_size: maximum number of batches to show; `None` means `len(dl)`.\n",
    "    \"\"\"\n",
    "    if iter_size is None:\n",
    "        iter_size = len(dl)\n",
    "    for i, item in enumerate(dl):\n",
    "        if i >= iter_size:\n",
    "            break\n",
    "        image = item[0]\n",
    "        label, length = item[1][0], item[1][1]\n",
    "        # Size the loop from the batch itself rather than the notebook-global `bs`:\n",
    "        # `bs` is only assigned in a later cell, and a final partial batch holds\n",
    "        # fewer than `bs` samples.\n",
    "        texts = [charset.get_text(label[j][0: length[j]].argmax(-1), padding=False)\n",
    "                 for j in range(len(label))]\n",
    "        print(f'iter {i}:', texts)\n",
    "        display(ToPILImage()(make_grid(image.cpu())))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Construct dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset read from the training directory, built with is_training=True.\n",
    "data1 = ImageDataset('data/training/ST', is_training=True)\n",
    "data1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Wrap the training dataset in a data bunch without a validation set.\n",
    "bs = 64\n",
    "data2 = ImageDataBunch.create(\n",
    "    train_ds=data1,\n",
    "    valid_ds=None,\n",
    "    bs=bs,\n",
    "    num_workers=1,\n",
    ")\n",
    "data2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalisation is currently disabled (see the commented-out line below), so\n",
    "# data3 is just another name for data2.\n",
    "#data3 = data2.normalize(imagenet_stats);data3\n",
    "data3 = data2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_all(data3.train_dl, 4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Add dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "kwargs = {'data_aug': False, 'is_training': False}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First evaluation dataset; options come from the `kwargs` dict defined above.\n",
    "data1 = ImageDataset('data/evaluation/IIIT5k_3000', **kwargs)\n",
    "data1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second evaluation dataset, built with the same `kwargs`.\n",
    "data2 = ImageDataset('data/evaluation/SVT', **kwargs)\n",
    "data2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data3 = ConcatDataset([data1, data2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train on data1 and validate on the concatenated dataset data3.\n",
    "bs = 64\n",
    "data4 = ImageDataBunch.create(\n",
    "    train_ds=data1,\n",
    "    valid_ds=data3,\n",
    "    bs=bs,\n",
    "    num_workers=1,\n",
    ")\n",
    "data4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(data4.train_dl), len(data4.valid_dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_all(data4.train_dl, 4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TEST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(data4.valid_dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Measure how long it takes to iterate over the validation loader.\n",
    "import time\n",
    "\n",
    "niter = 1000  # report the mean seconds per batch every `niter` batches\n",
    "seen = 0      # batches loaded since the last report\n",
    "start = time.perf_counter()  # monotonic clock, suited to measuring intervals\n",
    "for i, item in enumerate(progress_bar(data4.valid_dl), start=1):\n",
    "    seen += 1\n",
    "    if i % niter == 0:\n",
    "        print(i, (time.perf_counter() - start) / seen)\n",
    "        seen, start = 0, time.perf_counter()\n",
    "# Report the trailing window too; without this a loader with fewer than\n",
    "# `niter` batches would finish without printing any timing at all.\n",
    "if seen:\n",
    "    print(f'last {seen} batches:', (time.perf_counter() - start) / seen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Plot the first `num` training images on a grid with `ncols` columns.\n",
    "num = 20\n",
    "ncols = 4\n",
    "# Ceiling division: `num // ncols` rows is one row short whenever `num` is not\n",
    "# a multiple of `ncols`, and plt.subplot then rejects the out-of-range index.\n",
    "nrows = -(-num // ncols)\n",
    "plt.figure(figsize=(20, 10))\n",
    "for i in range(num):\n",
    "    plt.subplot(nrows, ncols, i + 1)\n",
    "    # transpose(1, 2, 0) moves axis 0 last - presumably CHW -> HWC for imshow; TODO confirm\n",
    "    plt.imshow(data4.train_ds[i][0].data.numpy().transpose(1, 2, 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show(path, image_key):\n",
    "    \"\"\"Read one image record from the LMDB at `path`, print its size and plot it.\n",
    "\n",
    "    Args:\n",
    "        path: LMDB location; passed to `lmdb.open` as `str(path)`.\n",
    "        image_key: str key of the image record, e.g. 'image-000002420'.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if `image_key` is not present in the database.\n",
    "    \"\"\"\n",
    "    # Hold the environment explicitly so it can be closed; leaving the\n",
    "    # transaction's `with` block only ends the transaction, not the environment.\n",
    "    env = lmdb.open(str(path), readonly=True, lock=False, readahead=False, meminit=False)\n",
    "    try:\n",
    "        with env.begin(write=False) as txn:\n",
    "            imgbuf = txn.get(image_key.encode())  # image\n",
    "    finally:\n",
    "        env.close()\n",
    "    if imgbuf is None:\n",
    "        # txn.get returns None for a missing key; fail with a clear message\n",
    "        # instead of a TypeError from the buffer write.\n",
    "        raise KeyError(f'{image_key!r} not found in {path}')\n",
    "    buf = six.BytesIO(imgbuf)\n",
    "    with warnings.catch_warnings():\n",
    "        warnings.simplefilter(\"ignore\", UserWarning)  # EXIF warning from TiffPlugin\n",
    "        x = PIL.Image.open(buf).convert('RGB')\n",
    "    print(x.size)\n",
    "    plt.imshow(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keys previously assigned here and then overridden: 'image-003118258', 'image-002780217'\n",
    "path = 'data/CVPR2016'\n",
    "image_key = 'image-002780218'\n",
    "show(path, image_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Key previously assigned here and then overridden: 'image-004668347'\n",
    "path = 'data/NIPS2014'\n",
    "image_key = 'image-006128516'\n",
    "show(path, image_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Key previously assigned here and then overridden: 'image-004668347'\n",
    "# NOTE(review): the evaluation cell above loads 'data/evaluation/IIIT5k_3000';\n",
    "# confirm that the different path used here is intended.\n",
    "path = 'data/IIIT5K_3000'\n",
    "image_key = 'image-000002420'\n",
    "show(path, image_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}